package com.jstarcraft.ai.jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.SingleWeightVectorModel;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.svm.PlattSMO;
import com.jstarcraft.ai.jsat.distributions.Distribution;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.lossfunctions.LogisticLoss;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.ListUtils;

import it.unimi.dsi.fastutil.ints.IntArrayList;

/**
 * This provides an implementation of regularized logistic regression using Dual
 * Coordinate Descent. This algorithm works well on both dense and sparse large
 * data sets. <br>
 * <br>
 * The regularized problem is of the form:<br>
 * C <big>&Sigma;</big> log(1+exp(-y<sub>i</sub>w<sup>T</sup>x<sub>i</sub>)) +
 * w<sup>T</sup>w/2 <br>
 * <br>
 * See:<br>
 * Yu, H.-F., Huang, F.-L., &amp; Lin, C.-J. (2010). <i>Dual Coordinate Descent
 * Methods for Logistic Regression and Maximum Entropy Models</i>. Machine
 * Learning, 85(1-2), 41–75. doi:10.1007/s10994-010-5221-8
 * 
 * @author Edward Raff
 */
public class LogisticRegressionDCD implements Classifier, Parameterized, SingleWeightVectorModel {

    private static final long serialVersionUID = -5813704270903243462L;
    /** Factor applied to C when picking the initial dual value (see {@link #train(ClassificationDataSet)}). */
    private static final double eps_1 = 1e-3;
    /** Upper bound on the initial dual value. */
    private static final double eps_2 = 1e-8;
    /** A starting point below this value is left alone and the coordinate is skipped. */
    private static final double MIN_Z = 1e-20;
    /** Maximum number of Newton steps spent on a single coordinate's sub-problem. */
    private static final int MAX_NEWTON_STEPS = 100;
    /** The sub-problem gradient is treated as zero once its magnitude drops below this value. */
    private static final double NEWTON_TOLERANCE = 1e-6;
    /** Training stops once no dual variable moved by more than this during a full pass. */
    private static final double CONVERGENCE_TOLERANCE = 1e-4;

    /** Learned weight vector; {@code null} until {@link #train(ClassificationDataSet)} has run. */
    private Vec w;
    /** Learned bias term; stays 0 unless {@link #useBias} is set. */
    private double bias;
    private boolean useBias;

    private double C;
    private int maxIterations;

    /**
     * Creates a new Logistic Regression learner that does no more than 100 training
     * iterations with a default regularization tradeoff of C = 1
     */
    public LogisticRegressionDCD() {
        this(1.0);
    }

    /**
     * Creates a new Logistic Regression learner that does no more than 100 training
     * iterations.
     * 
     * @param C the regularization tradeoff term
     */
    public LogisticRegressionDCD(double C) {
        this(C, 100);
    }

    /**
     * Creates a new Logistic Regression learner
     * 
     * @param C             the regularization tradeoff term
     * @param maxIterations the maximum number of iterations through the data set
     * @throws IllegalArgumentException if either argument is rejected by
     *                                  {@link #setC(double)} or
     *                                  {@link #setMaxIterations(int)}
     */
    public LogisticRegressionDCD(double C, int maxIterations) {
        setC(C);
        setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor. The weight vector, when present, is cloned so the copy does
     * not share it with the source object.
     * 
     * @param toCopy the object to copy
     */
    protected LogisticRegressionDCD(LogisticRegressionDCD toCopy) {
        this(toCopy.C, toCopy.maxIterations);
        if (toCopy.w != null)
            this.w = toCopy.w.clone();
        this.bias = toCopy.bias;
        this.useBias = toCopy.useBias;
    }

    /**
     * Sets the regularization trade-off term. Larger values reduce the amount of
     * regularization, and smaller values increase the regularization.
     * 
     * @param C the positive regularization tradeoff value
     * @throws IllegalArgumentException if {@code C} is not positive, is infinite, or
     *                                  is NaN
     */
    public void setC(double C) {
        if (C <= 0 || Double.isInfinite(C) || Double.isNaN(C))
            throw new IllegalArgumentException("C must be a positive constant, not " + C);
        this.C = C;
    }

    /**
     * Returns the regularization tradeoff parameter
     * 
     * @return the regularization tradeoff parameter
     */
    public double getC() {
        return C;
    }

    /**
     * Sets the maximum number of iterations the algorithm is allowed to run for.
     * 
     * @param maxIterations the maximum number of iterations
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1)
            throw new IllegalArgumentException("iterations must be a positive value, not " + maxIterations);
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the maximum number of iterations the algorithm is allowed to run
     * 
     * @return the maximum number of iterations the algorithm is allowed to run
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Sets whether or not an implicit bias term should be added to the model.
     * 
     * @param useBias {@code true} to add a bias term, {@code false} to exclude the
     *                bias term.
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * Returns {@code true} if a bias term is in use, {@code false} otherwise.
     * 
     * @return {@code true} if a bias term is in use, {@code false} otherwise.
     */
    public boolean isUseBias() {
        return useBias;
    }

    /**
     * Returns the model's weight vector itself, not a copy.
     * 
     * @return the weight vector, or {@code null} if the model has not been trained
     */
    @Override
    public Vec getRawWeight() {
        return w;
    }

    /**
     * Returns the bias term of the model.
     * 
     * @return the bias term
     */
    @Override
    public double getBias() {
        return bias;
    }

    /**
     * Returns the weight vector for the given index. This model has a single weight
     * vector, and every index below 1 is answered with it.
     * 
     * @param index the index of the weight vector
     * @return the model's weight vector
     * @throws IndexOutOfBoundsException if {@code index} is 1 or larger
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1)
            return getRawWeight();
        else
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias term for the given index. This model has a single bias term,
     * and every index below 1 is answered with it.
     * 
     * @param index the index of the bias term
     * @return the model's bias term
     * @throws IndexOutOfBoundsException if {@code index} is 1 or larger
     */
    @Override
    public double getBias(int index) {
        if (index < 1)
            return getBias();
        else
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the number of weight vectors in this model.
     * 
     * @return always 1
     */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Classifies a data point by passing the linear score
     * {@code w . x + bias} to {@link LogisticLoss#classify(double)}.
     * 
     * @param data the data point whose numerical values are scored
     * @return the result produced by {@code LogisticLoss.classify} for the score
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return LogisticLoss.classify(w.dot(data.getNumericalValues()) + bias);
    }

    /**
     * Trains the model. The {@code parallel} flag is ignored; this simply delegates
     * to {@link #train(ClassificationDataSet)}.
     * 
     * @param dataSet  the binary classification data to learn from
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains the model with dual coordinate descent. Every pass visits the data
     * points in a freshly shuffled order and re-solves the one-variable dual
     * sub-problem of each point with Newton's method. Training ends after
     * {@link #getMaxIterations()} passes, or earlier once no dual variable changed
     * by more than a small tolerance during a pass.
     * <p>
     * Any previously learned weight vector and bias are discarded. An empty data
     * set produces an all-zero weight vector and a zero bias.
     * 
     * @param dataSet the binary classification data to learn from
     * @throws FailedToFitException if the data set does not have exactly 2 classes
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2)
            throw new FailedToFitException("Logistic Regression is a binary classifier, can can not handle " + dataSet.getClassSize() + " class problems");
        final int N = dataSet.size();
        List<Vec> x = dataSet.getDataVectors();
        double[] alpha = new double[N];
        double[] alphaPrime = new double[N];
        double[] Q_ii = new double[N];
        int[] y = new int[N];

        /*
         * All points start with a small eps, because LR doesn't tend to zero out
         * coefficients. But we expect a few alphas to quickly go to larger values.
         * The value is kept in a local instead of being read back from alpha[0], so
         * that a data set without any points does not index into an empty array.
         */
        final double alphaInit = Math.min(eps_1 * C, eps_2);
        Arrays.fill(alpha, alphaInit);
        Arrays.fill(alphaPrime, C - alphaInit);
        w = new DenseVector(dataSet.getNumNumericalVars());
        bias = 0;

        for (int i = 0; i < N; i++) {
            // category 0 becomes -1, category 1 becomes +1
            y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
            Vec x_i = x.get(i);
            Q_ii[i] = x_i.dot(x_i);
            // every alpha still holds the same initial value here
            w.mutableAdd(alphaInit * y[i], x_i);
            if (useBias)
                bias += alphaInit * y[i];
        }

        IntArrayList permutation = new IntArrayList(N);
        ListUtils.addRange(permutation, 0, N, 1);

        for (int iter = 0; iter < maxIterations; iter++) {
            Collections.shuffle(permutation);

            // largest magnitude by which any alpha moved during this pass
            double maxChange = 0;

            for (int i : permutation) {
                Vec x_i = x.get(i);
                // Step 1. Decide which of the two dual variables to solve for.
                final double c1 = alpha[i], c2 = alphaPrime[i];
                final double a = Q_ii[i], b = y[i] * (w.dot(x_i) + bias);
                // alphaPrime is always stored as C - alpha, so s equals C up to rounding
                final double z_m = (c2 - c1) / 2, s = c1 + c2;
                final boolean case1 = z_m >= -b / a;

                // The two cases are mirror images: swap c1 for c2 and negate b.
                final double c = case1 ? c1 : c2;
                final double signedB = case1 ? b : -b;

                double z = initialPoint(c, s);// see eq (35)

                // what if z is very small? Leave it alone..
                if (z < MIN_Z)
                    continue;

                // Step 2. Algorithm 4 solving equation (18)
                z = solveSubProblem(z, c, a, signedB, s);

                // Step 4. alpha_i = Z1, alpha'_i = Z2.
                if (case1) {
                    alpha[i] = z;
                    alphaPrime[i] = C - z;
                } else {
                    alpha[i] = C - z;
                    alphaPrime[i] = z;
                }

                // Step 3. w = w + (Z1 - alpha_i) yi xi
                final double change = alpha[i] - c1;
                w.mutableAdd(change * y[i], x_i);
                if (useBias)
                    bias += change * y[i];
                /*
                 * Compare magnitudes. A signed maximum starting at 0 never sees negative
                 * updates, so a pass in which every alpha shrinks would wrongly look
                 * converged.
                 */
                maxChange = Math.max(maxChange, Math.abs(change));
            }

            // Convergence check
            if (maxChange < CONVERGENCE_TOLERANCE)
                return;
        }

    }

    /**
     * Picks the Newton starting point for one coordinate: the current value
     * {@code c}, scaled down by a factor of 0.1 when it is at least half of
     * {@code s}.
     * 
     * @param c the current value of the dual variable being solved for
     * @param s the sum of the two dual variables of the point
     * @return the starting point for the Newton iterations
     */
    private static double initialPoint(double c, double s) {
        return c >= s / 2 ? 0.1 * c : c;
    }

    /**
     * Runs at most {@link #MAX_NEWTON_STEPS} Newton steps on the one-variable
     * sub-problem of a single data point, stopping early once the gradient
     * magnitude falls below {@link #NEWTON_TOLERANCE}.
     * 
     * @param z the starting point
     * @param c the current value of the dual variable being solved for
     * @param a the squared norm of the data point
     * @param b the label-signed score of the data point, negated when the second
     *          dual variable is the one being solved for
     * @param s the sum of the two dual variables of the point
     * @return the value of {@code z} after the last step taken
     */
    private double solveSubProblem(double z, double c, double a, double b, double s) {
        for (int step = 0; step < MAX_NEWTON_STEPS; step++) {
            double gP = Math.log(z / (C - z));
            gP += a * (z - c) + b;
            // check if "0"
            if (Math.abs(gP) < NEWTON_TOLERANCE)
                break;

            final double gPP = a + s / (z * (s - z));
            final double d = -gP / gPP;

            if (z + d <= 0)
                z *= 0.1;// unspecified shrinkage term: just use 0.1
            else
                z += d;
        }
        return z;
    }

    /**
     * Reports whether per-point weights are honoured during training.
     * 
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this learner through the copy constructor.
     * 
     * @return a new {@code LogisticRegressionDCD} with the same parameters and a
     *         cloned weight vector
     */
    @Override
    public Classifier clone() {
        return new LogisticRegressionDCD(this);
    }

    /**
     * Returns the weight vector used to compute results via a dot product. <br>
     * Do not modify this value, or you will alter the results returned.
     * 
     * @return the learned weight vector for prediction
     */
    public Vec getWeightVec() {
        return w;
    }

    /**
     * Guess the distribution to use for the regularization term
     * {@link #setC(double) C} in Logistic Regression. Delegates to
     * {@link PlattSMO#guessC(DataSet)}.
     *
     * @param d the data set to get the guess for
     * @return the guess for the C parameter
     */
    public static Distribution guessC(DataSet d) {
        return PlattSMO.guessC(d);
    }
}
